{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Per-question score records (dicts; keys used below include 'question', 'head',\n",
    "# 'answers', 'candidates', 'scores').\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.\n",
    "with open('webqsp_scores_full_kg_fixed.pkl', 'rb') as f:\n",
    "    cws = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "from tqdm import tqdm\n",
    "# Load the KG: one tab-separated triple per line; t[0] and t[2] are the entities,\n",
    "# t[1] the relation (as used by the cells below).\n",
    "with open('../../data/fbwq_full/train.txt', 'r') as f:\n",
    "    triples = [line.strip().split('\\t') for line in f]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5780246/5780246 [00:26<00:00, 220886.31it/s]\n"
     ]
    }
   ],
   "source": [
    "# Undirected entity graph with one edge per (t[0], t[2]) pair; relation labels are\n",
    "# not stored on the edges. add_edge inserts missing endpoints itself (u before v).\n",
    "G = nx.Graph()\n",
    "for triple in tqdm(triples):\n",
    "    G.add_edge(triple[0], triple[2])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5780246/5780246 [00:16<00:00, 343697.34it/s]\n"
     ]
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "# Directed index: (t[0], t[2]) -> set of relations t[1] linking the pair in that order.\n",
    "triples_dict = defaultdict(set)\n",
    "for triple in tqdm(triples):\n",
    "    key = (triple[0], triple[2])\n",
    "    triples_dict[key].add(triple[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getRelationsFromKG(head, tail, kg=None):\n",
    "    \"\"\"Return the relations stored for the directed entity pair (head, tail).\n",
    "\n",
    "    Args:\n",
    "        head: entity in subject position (t[0] of a triple).\n",
    "        tail: entity in object position (t[2] of a triple).\n",
    "        kg: mapping (head, tail) -> set of relations; defaults to the global triples_dict.\n",
    "\n",
    "    Returns:\n",
    "        The stored relation set for a known pair (do not mutate it), or a new empty\n",
    "        set otherwise. Uses .get() so unknown pairs are NOT inserted into a defaultdict.\n",
    "    \"\"\"\n",
    "    if kg is None:\n",
    "        kg = triples_dict\n",
    "    return kg.get((head, tail), set())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getRelationsInPath(G, e1, e2):\n",
    "    \"\"\"Collect the relations along one shortest path between e1 and e2 in G.\n",
    "\n",
    "    Each consecutive node pair (u, v) of the path is looked up as a directed (u, v)\n",
    "    key via getRelationsFromKG; an edge stored only as (v, u) contributes nothing.\n",
    "    NOTE(review): G is undirected, so confirm whether reverse triples exist in the KG.\n",
    "\n",
    "    Returns:\n",
    "        A set of relation names (empty when the path has a single node, e.g. e1 == e2).\n",
    "    Errors raised by nx.shortest_path (e.g. no path) propagate to the caller.\n",
    "    \"\"\"\n",
    "    path = nx.shortest_path(G, e1, e2)\n",
    "    relations = set()\n",
    "    for head, tail in zip(path, path[1:]):\n",
    "        relations.update(getRelationsFromKG(head, tail))\n",
    "    return relations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import argparse\n",
    "import operator\n",
    "from torch.nn import functional as F\n",
    "import networkx as nx\n",
    "from collections import defaultdict\n",
    "from pruning_model import PruningModel\n",
    "from pruning_dataloader import DatasetPruning, DataLoaderPruning\n",
    "\n",
    "# Expose only GPU 1 (only effective if CUDA has not been initialised yet in this kernel).\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "# Relation vocabulary: each line is a relation name and its integer id, tab-separated.\n",
    "rel2idx = {}\n",
    "idx2rel = {}\n",
    "with open('../../data/fbwq_full/relations_all.dict', 'r') as f:\n",
    "    for line in f:\n",
    "        fields = line.strip().split('\\t')\n",
    "        rel, rel_id = fields[0], int(fields[1])\n",
    "        rel2idx[rel] = rel_id\n",
    "        idx2rel[rel_id] = rel\n",
    "\n",
    "def process_data_file(fname, rel2idx, idx2rel):\n",
    "    \"\"\"Read a pruning data file into (question, relation_ids, raw_question) tuples.\n",
    "\n",
    "    Each non-blank line holds the question text and a '|'-separated relation list,\n",
    "    separated by a tab. Blank lines (e.g. a trailing empty line) are skipped.\n",
    "\n",
    "    Args:\n",
    "        fname: path of the tab-separated data file.\n",
    "        rel2idx: relation name -> integer id; an unknown relation raises KeyError.\n",
    "        idx2rel: unused; kept for call compatibility.\n",
    "\n",
    "    Returns:\n",
    "        List of (text before the first '[', [relation ids], full stripped question).\n",
    "    \"\"\"\n",
    "    data = []\n",
    "    with open(fname, 'r') as f:\n",
    "        for line in f:\n",
    "            if not line.strip():\n",
    "                continue\n",
    "            fields = line.strip().split('\\t')\n",
    "            full_question = fields[0].strip()\n",
    "            # TODO only works for webqsp (entity written as '[...]'); to remove the\n",
    "            # entity from metaqa questions, use something else.\n",
    "            question = full_question.split('[')[0]\n",
    "            rel_id_list = [rel2idx[rel] for rel in fields[1].split('|')]\n",
    "            data.append((question, rel_id_list, full_question))\n",
    "    return data\n",
    "\n",
    "model = PruningModel(rel2idx, idx2rel, 0.0)\n",
    "checkpoint_file = \"../../pretrained_models/relation_matching_models/webqsp.pt\"\n",
    "# map_location returns each storage unchanged, i.e. the checkpoint is loaded on CPU.\n",
    "model.load_state_dict(torch.load(checkpoint_file, map_location=lambda storage, loc: storage))\n",
    "# The model is only used for inference below. Assuming PruningModel is a torch\n",
    "# nn.Module (it has load_state_dict), eval() disables training-mode behaviour such as dropout.\n",
    "model.eval()\n",
    "\n",
    "data = process_data_file('../../data/fbwq_full/pruning_train.txt', rel2idx, idx2rel)\n",
    "dataset = DatasetPruning(data=data, rel2idx = rel2idx, idx2rel = idx2rel)\n",
    "print('Done')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getHead(q):\n",
    "    \"\"\"Return the text between the first '[' and the next ']' in q, stripped.\n",
    "\n",
    "    Raises IndexError when q contains no '['.\n",
    "    \"\"\"\n",
    "    after_bracket = q.split('[')[1]\n",
    "    return after_bracket.split(']')[0].strip()\n",
    "\n",
    "def get2hop(graph, entity):\n",
    "    \"\"\"Return all nodes reachable from entity in one or two steps, excluding entity.\n",
    "\n",
    "    graph[node] must iterate over node's neighbours (networkx graph or dict of lists).\n",
    "    \"\"\"\n",
    "    first_ring = graph[entity]\n",
    "    reachable = set(first_ring)\n",
    "    for neighbour in first_ring:\n",
    "        reachable.update(graph[neighbour])\n",
    "    reachable.discard(entity)\n",
    "    return reachable\n",
    "\n",
    "def get3hop(graph, entity):\n",
    "    \"\"\"Return all nodes reachable from entity within three steps, excluding entity.\n",
    "\n",
    "    Breadth-first expansion over de-duplicated frontiers, so each node's adjacency is\n",
    "    read at most once instead of once per path that reaches it.\n",
    "    graph[node] must iterate over node's neighbours (networkx graph or dict of lists).\n",
    "    \"\"\"\n",
    "    visited = {entity}\n",
    "    frontier = {entity}\n",
    "    for _ in range(3):\n",
    "        next_frontier = set()\n",
    "        for node in frontier:\n",
    "            next_frontier.update(graph[node])\n",
    "        frontier = next_frontier - visited\n",
    "        visited |= frontier\n",
    "    visited.discard(entity)\n",
    "    return visited\n",
    "\n",
    "def get1hop(graph, entity):\n",
    "    \"\"\"Return the direct neighbours of entity, excluding entity itself (self-loops).\"\"\"\n",
    "    neighbours = set(graph[entity])\n",
    "    neighbours.discard(entity)\n",
    "    return neighbours\n",
    "\n",
    "\n",
    "def getnhop(graph, entity, hops=1):\n",
    "    \"\"\"Return all nodes within `hops` steps of entity, excluding entity itself.\n",
    "\n",
    "    Breadth-first expansion that honours any hop count (hops = 1/2/3 give the\n",
    "    1-, 2- and 3-hop neighbourhoods; hops <= 0 gives an empty set).\n",
    "\n",
    "    Args:\n",
    "        graph: graph[node] iterates over node's neighbours (networkx graph or dict of lists).\n",
    "        entity: start node.\n",
    "        hops: maximum number of steps.\n",
    "    \"\"\"\n",
    "    visited = {entity}\n",
    "    frontier = {entity}\n",
    "    for _ in range(int(hops)):\n",
    "        if not frontier:\n",
    "            break\n",
    "        next_frontier = set()\n",
    "        for node in frontier:\n",
    "            next_frontier.update(graph[node])\n",
    "        frontier = next_frontier - visited\n",
    "        visited |= frontier\n",
    "    visited.discard(entity)\n",
    "    return visited\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getAllRelations(head, tail):\n",
    "    \"\"\"Return the relations on one shortest path from head to tail in the global graph G.\n",
    "\n",
    "    Only paths of 1 or 2 edges are considered; longer paths, head == tail, missing\n",
    "    nodes and disconnected pairs all give an empty set. Each path edge (u, v) is\n",
    "    looked up as a directed (u, v) key in the global triples_dict, so an edge stored\n",
    "    only as (v, u) contributes no relation even though G is undirected.\n",
    "    NOTE(review): confirm whether the KG also contains reverse triples.\n",
    "\n",
    "    Returns a new set, so callers cannot mutate triples_dict through it.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # A single shortest_path call yields both the path and its length.\n",
    "        path = nx.shortest_path(G, head, tail)\n",
    "    except nx.NetworkXException:\n",
    "        # Missing node or no path. Unlike a bare except, this lets KeyboardInterrupt through.\n",
    "        return set()\n",
    "    if len(path) - 1 not in (1, 2):\n",
    "        return set()\n",
    "    relations = set()\n",
    "    for u, v in zip(path, path[1:]):\n",
    "        # .get() avoids inserting empty entries into the defaultdict.\n",
    "        relations.update(triples_dict.get((u, v), ()))\n",
    "    return relations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def removeHead(question):\n",
    "    \"\"\"Return the question text before the first '[' (drops a '[head]' suffix).\n",
    "\n",
    "    Questions without '[' are returned unchanged.\n",
    "    \"\"\"\n",
    "    return question.split('[')[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on a subset of questions for faster testing/tuning:\n",
    "# set max_questions to an int (e.g. 100) to use only the first N questions, or None for all.\n",
    "max_questions = None\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.\n",
    "with open('webqsp_scores_full_kg_fixed.pkl', 'rb') as f:\n",
    "    cws = pickle.load(f)\n",
    "# Clamp so the accuracy denominator never exceeds the number of available questions.\n",
    "num_for_testing = len(cws) if max_questions is None else min(max_questions, len(cws))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1582/1582 [14:03<00:00,  1.87it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy is 0.6731984829329962\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Alternative relation matching restricted to the head's 2-hop neighbourhood.\n",
    "# It was used in the ablation, but its parameters do not exactly match the paper's\n",
    "# (this notebook was used for experimentation). It is much faster than the\n",
    "# algorithm described in the paper.\n",
    "\n",
    "top_k_rels = 5                # number of top-scoring relations to consider\n",
    "pruning_rels_threshold = 0.5  # minimum relation score for a relation to be kept\n",
    "\n",
    "model.eval()  # inference only; disables training-mode behaviour such as dropout\n",
    "questions = cws[:num_for_testing]\n",
    "num_correct = 0\n",
    "with torch.no_grad():  # scoring needs no autograd graph\n",
    "    for q in tqdm(questions):\n",
    "        question = q['question']\n",
    "        answers = q['answers']\n",
    "        candidates = q['candidates']\n",
    "        head = q['head']\n",
    "        question_tokenized, attention_mask = dataset.tokenize_question(question)\n",
    "        scores = model.get_score_ranked(question_tokenized=question_tokenized, attention_mask=attention_mask)\n",
    "        pruning_rels_scores, pruning_rels_torch = torch.topk(scores, top_k_rels)\n",
    "        pruning_rels = {idx2rel[p.item()]\n",
    "                        for s, p in zip(pruning_rels_scores, pruning_rels_torch)\n",
    "                        if s > pruning_rels_threshold}\n",
    "\n",
    "        # Pick the first candidate (in list order) that lies within 2 hops of the head\n",
    "        # and shares at least one predicted relation with it; otherwise fall back to\n",
    "        # the first candidate (or no answer if there are no candidates).\n",
    "        head_nbhood = get2hop(G, head)\n",
    "        my_answer = candidates[0] if candidates else \"\"\n",
    "        for c in candidates:\n",
    "            if c in head_nbhood and pruning_rels.intersection(getAllRelations(head, c)):\n",
    "                my_answer = c\n",
    "                break\n",
    "        if my_answer in answers:\n",
    "            num_correct += 1\n",
    "print('Accuracy is', num_correct/len(questions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Algorithm mentioned in paper, sec 4.4.1.\n",
    "# Slower than the neighbourhood version, but has no neighbourhood restriction.\n",
    "top_k_rels = 2                # number of top-scoring relations to consider\n",
    "pruning_rels_threshold = 0.5  # threshold to consider as written in sec 4.4.1\n",
    "gamma = 1.0                   # weight of the relation-match count in the total score\n",
    "\n",
    "model.eval()  # inference only; disables training-mode behaviour such as dropout\n",
    "num_correct = 0\n",
    "with torch.no_grad():  # scoring needs no autograd graph\n",
    "    for q in tqdm(cws[:num_for_testing]):\n",
    "        question = q['question']\n",
    "        answers = q['answers']\n",
    "        candidates = q['candidates']\n",
    "        candidates_scores = q['scores']\n",
    "        head = q['head']\n",
    "        question_tokenized, attention_mask = dataset.tokenize_question(question)\n",
    "        scores = model.get_score_ranked(question_tokenized=question_tokenized, attention_mask=attention_mask)\n",
    "        pruning_rels_scores, pruning_rels_torch = torch.topk(scores, top_k_rels)\n",
    "        pruning_rels = {idx2rel[p.item()]\n",
    "                        for s, p in zip(pruning_rels_scores, pruning_rels_torch)\n",
    "                        if s > pruning_rels_threshold}\n",
    "\n",
    "        # Total score = candidate score + gamma * (#predicted relations found on a\n",
    "        # shortest path of at most 2 edges from head to the candidate).\n",
    "        # Start from -inf so an answer is chosen even when every total score is <= 0.\n",
    "        max_score = float('-inf')\n",
    "        max_score_answer = \"\"\n",
    "        for score, c in zip(candidates_scores, candidates):\n",
    "            relscore = len(getAllRelations(head, c).intersection(pruning_rels))\n",
    "            totalscore = score + gamma * relscore\n",
    "            if totalscore > max_score:\n",
    "                max_score = totalscore\n",
    "                max_score_answer = c\n",
    "        if max_score_answer in answers:\n",
    "            num_correct += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of evaluated questions whose selected answer is among the gold answers.\n",
    "accuracy = num_correct / num_for_testing\n",
    "print('Accuracy is', accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "who plays ken barlow in coronation street [m.01_2n]\n",
      "film.film.starring\n",
      "film.performance.actor\n",
      "film.performance.character\n",
      "film.film_character.portrayed_in_films\n",
      "film.actor.film\n",
      "['m.05h48b', 'm.01j_gs', 'm.087k1_', 'm.07ssc', 'm.0dfw0', 'm.0bvxv', 'm.05bmx95', 'm.02rn00y', 'm.0y_m48z', 'm.09lxv9']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'award.award_honor.award_winner', 'award.award_winning_work.awards_won'}"
      ]
     },
     "execution_count": 142,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The following code is just for investigating a single question.\n",
    "# It is kept because it might be useful to someone who wants to explore.\n",
    "\n",
    "qid = 2\n",
    "model.eval()\n",
    "question = cws[qid]['question']\n",
    "print(question)\n",
    "# Score relations the same way as the evaluation loops (tokenised question text).\n",
    "question_tokenized, attention_mask = dataset.tokenize_question(question)\n",
    "with torch.no_grad():\n",
    "    scores = model.get_score_ranked(question_tokenized=question_tokenized, attention_mask=attention_mask)\n",
    "top_rel_scores, top_rel_ids = torch.topk(scores, 5)\n",
    "for rel_id in top_rel_ids:\n",
    "    print(idx2rel[rel_id.item()])\n",
    "candidates = cws[qid]['candidates']\n",
    "print(candidates)\n",
    "head = getHead(question)\n",
    "tail = candidates[1]  # inspect the second candidate in list order\n",
    "getAllRelations(head, tail)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m.01_2n\n",
      "m.0x15h3d\n",
      "m.01j_gs\n"
     ]
    }
   ],
   "source": [
    "# Print the nodes on the shortest path between `head` and `tail` (set in the cell above).\n",
    "# `paths` only exists as a local inside getAllRelations, so recompute the path here.\n",
    "for node in nx.shortest_path(G, head, tail):\n",
    "    print(node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'tv.tv_program.regular_cast'}"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Manual inspection: relations stored for the directed pair (m.01_2n, m.0bvv2dt).\n",
    "getRelationsFromKG('m.01_2n', 'm.0bvv2dt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'government.government_position_held.office_position_or_title'}"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Manual inspection: relations stored for the directed pair (m.04j60kh, m.02_bcst).\n",
    "getRelationsFromKG('m.04j60kh', 'm.02_bcst')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
